import sys

from centralized_verification.shields.slugs_shielding.combine_identical_states import load_centralized_shield, \
    load_decentralized_shield, ShieldSpec, DecentralizedShieldSpec


def calc_avg_cent_shield_actions(shield: ShieldSpec):
    """Summarise how many actions a centralized shield permits per state.

    For every state in ``shield`` the number of permitted actions is taken
    as ``len(state.actions)``.

    Returns:
        A ``(mean, minimum)`` tuple: the average action count over all states
        and the smallest action count seen in any single state.

    Raises:
        ValueError: if ``shield`` has no states (``min`` of an empty sequence).
    """
    counts = [len(entry.actions) for entry in shield.values()]
    # Take the minimum before dividing so an empty shield fails with the
    # ValueError from min(), exactly as before.
    fewest = min(counts)
    mean = sum(counts) / len(shield)
    return mean, fewest


def _joint_action_count(ordering):
    """Return the product of ``len(a)`` over every entry ``a`` of ``ordering.actions``."""
    # Each entry is presumably one agent's set of allowed actions, so the
    # product is the number of joint actions -- TODO confirm against the spec type.
    count = 1
    for agent_actions in ordering.actions:
        count *= len(agent_actions)
    return count


def calc_avg_dec_shield_actions(shield: DecentralizedShieldSpec):
    """Summarise how many joint actions a decentralized shield permits.

    For each state, every entry of ``state.action_permutations`` (an agent
    ordering) contributes a joint action count equal to the product of the
    sizes of its ``actions`` entries. Those counts are averaged per state, and
    the per-state averages are then averaged over all states.

    Returns:
        A ``(mean, minimum)`` tuple: the mean of the per-state averages, and
        the smallest joint action count of any ordering in any state.

    Raises:
        ZeroDivisionError: if ``shield`` is empty or a state has no
            ``action_permutations``.
    """
    state_averages = []
    state_minimums = []
    for shield_state in shield.values():
        ordering_counts = [_joint_action_count(ordering)
                           for ordering in shield_state.action_permutations]
        state_averages.append(sum(ordering_counts) / len(ordering_counts))
        state_minimums.append(min(ordering_counts))

    # The true minimum is tracked instead of seeding with a large sentinel.
    # A sentinel such as 99999999 would be silently reported whenever every
    # joint action count reached 10**8 (e.g. 8 agents x 10 actions).
    # The division is evaluated first, so an empty shield still raises
    # ZeroDivisionError.
    return sum(state_averages) / len(shield), min(state_minimums)


if __name__ == '__main__':
    # Expects the shield name as the first command-line argument; it is passed
    # to both the centralized and decentralized loaders.
    if len(sys.argv) < 2:
        # Fail with a usage hint instead of an IndexError traceback.
        print(f"usage: {sys.argv[0]} SHIELD_NAME", file=sys.stderr)
        sys.exit(2)
    name = sys.argv[1]

    cent_shield = load_centralized_shield(name)
    dec_shield = load_decentralized_shield(name)

    # Each line prints a (mean actions per state, minimum actions) tuple.
    print(f"Centralized shield actions:   {calc_avg_cent_shield_actions(cent_shield)}")
    print(f"Decentralized shield actions: {calc_avg_dec_shield_actions(dec_shield)}")
